package me.zhengjie.app.action.interceptor;

import cn.hutool.crypto.digest.DigestUtil;
import com.google.common.base.Charsets;
import com.google.common.collect.Maps;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import me.zhengjie.app.action.annotation.PassApiAuth;
import me.zhengjie.exception.ServerException;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;
import org.springframework.web.util.WebUtils;

import javax.annotation.PostConstruct;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.security.MessageDigest;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * API security interceptor.
 *
 * <p>When {@link ApiAuthConfig#getEnabled()} is {@code true}, every non-OPTIONS request that is not
 * whitelisted and whose handler method is not annotated with {@link PassApiAuth} must carry the
 * {@code appId}, {@code timestamp} (epoch millis) and {@code appSign} headers, where
 * {@code appSign = sha1Hex(appId + "&" + timestamp + "&" + appKey + "&" + body)}.
 * {@code appKey} is the key configured for {@code appId}, and {@code body} is the raw POST body,
 * or {@code ""} when the body is not part of the signature (see {@link #getRequestBody}).
 * Verification failures are reported by throwing {@link ServerException}.
 */
@Slf4j
@Component
public class ApiAuthenticationInterceptor extends HandlerInterceptorAdapter {
    /**
     * appId -> appKey, populated once from {@link ApiAuthConfig} in {@link #init()}.
     */
    private final Map<String, String> API_KEYS_MAPPING = Maps.newConcurrentMap();
    @Autowired
    private ApiAuthConfig apiAuthConfig;
    /**
     * Request header that controls how the request body is handled.
     */
    public static final String HEADER_HANDLE_TYPE_FIELD = "handleType";
    /**
     * Value of {@link #HEADER_HANDLE_TYPE_FIELD} that excludes the body from the signature.
     */
    public static final String HEADER_VALUE_NOT_VALIDATE_BODY = "NOT_VALIDATE_BODY";

    /**
     * Separator ("&amp;") placed between the signed fields.
     */
    public static final String PARTING_TAG = "&";

    public ApiAuthenticationInterceptor() { }

    /**
     * Registers the configured appId/appKey pair.
     *
     * <p>Both values must be non-blank. A half-configured pair would put a null into the
     * ConcurrentHashMap (NPE at startup), or turn a blank key into a trivially forgeable signature.
     */
    @PostConstruct
    private void init() {
        if (StringUtils.isNoneBlank(apiAuthConfig.getAppId(), apiAuthConfig.getAppKey())) {
            API_KEYS_MAPPING.put(apiAuthConfig.getAppId(), apiAuthConfig.getAppKey());
        } else if (Boolean.TRUE.equals(apiAuthConfig.getEnabled())) {
            log.warn("API认证已开启但appId/appKey未完整配置, 所有签名校验都将失败");
        }
    }

    /**
     * Verifies the request signature.
     *
     * @return {@code true} when the request may proceed
     * @throws ServerException when the auth headers are missing, malformed, expired, refer to an
     *                         unknown appId, or the signature does not match
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // A null "enabled" (not configured) is treated as disabled.
        if (!Boolean.TRUE.equals(apiAuthConfig.getEnabled())) {
            return true;
        }
        if (apiAuthConfig.inWhitelists(request.getRequestURI())) {
            return true;
        }
        // CORS preflight requests carry no auth headers.
        if (RequestMethod.OPTIONS.name().equals(request.getMethod())) {
            return super.preHandle(request, response, handler);
        }
        if (handler instanceof HandlerMethod
                && Objects.nonNull(((HandlerMethod) handler).getMethodAnnotation(PassApiAuth.class))) {
            return true;
        }

        String appId = request.getHeader("appId");
        String timestamp = request.getHeader("timestamp");
        String appSign = request.getHeader("appSign");

        if (StringUtils.isAnyBlank(appId, timestamp, appSign)) {
            log.warn("API认证参数不完整 [appId:{} , timestamp:{} , appSign:{}]", appId, timestamp, appSign);
            throw new ServerException("接口认证参数不完整");
        }

        long requestTime;
        try {
            requestTime = Long.parseLong(timestamp);
        } catch (NumberFormatException e) {
            log.warn("API认证参数格式错误 [appId:{} , timestamp:{} , appSign:{}]", appId, timestamp, appSign);
            throw new ServerException("接口认证参数格式错误");
        }

        // long arithmetic: "timeOut * 1000" in int overflows for timeouts above ~24 days.
        long timeoutMillis = apiAuthConfig.getTimeOut() * 1000L;
        long now = System.currentTimeMillis();
        // Reject timestamps outside the window in BOTH directions. Otherwise a far-future
        // timestamp would produce a signature that never expires.
        if (requestTime < now - timeoutMillis || requestTime > now + timeoutMillis) {
            log.warn("API认证参数过期 [appId:{} , timestamp:{} , appSign:{}]", appId, timestamp, appSign);
            throw new ServerException("接口认证参数过期");
        }

        String appKey = API_KEYS_MAPPING.get(appId);
        if (appKey == null) {
            // Without this check the digest would be computed with the literal string "null" as
            // the key, which anyone can reproduce.
            log.warn("未知的appId [appId:{} , timestamp:{}]", appId, timestamp);
            throw new ServerException("非法签名,数据可能被非法篡改");
        }

        String requestBody = getRequestBody(request);
        String digestedSign = DigestUtil.sha1Hex(appId + PARTING_TAG
                + timestamp + PARTING_TAG
                + appKey + PARTING_TAG + requestBody);

        // Constant-time comparison, so response timing does not leak how much of the signature matched.
        if (!MessageDigest.isEqual(digestedSign.getBytes(Charsets.UTF_8), appSign.getBytes(Charsets.UTF_8))) {
            // Never log appKey or the expected signature: either one lets a log reader forge requests.
            log.warn("非法签名,数据可能被非法篡改[appId:{} , timestamp:{} , appSign:{}]", appId, timestamp, appSign);
            throw new ServerException("非法签名,数据可能被非法篡改");
        }

        return super.preHandle(request, response, handler);
    }

    /**
     * Returns the body text that takes part in the signature.
     *
     * <p>The result is {@code ""} in three cases: the client sent
     * {@code handleType: NOT_VALIDATE_BODY}, the method is not POST, or the body was not cached by
     * {@link ContentCachingRequestWrapperFilter} (e.g. a non-JSON request). The raw stream is never
     * read here, because that would consume it before the controller sees it.
     */
    private String getRequestBody(HttpServletRequest request) {
        if (HEADER_VALUE_NOT_VALIDATE_BODY.equals(request.getHeader(HEADER_HANDLE_TYPE_FIELD))) {
            return "";
        }
        // NOTE(review): only POST bodies are signed; PUT/PATCH bodies are not covered.
        if (!"POST".equals(request.getMethod())) {
            return "";
        }
        // Unwrap any outer ServletRequestWrappers to find the caching wrapper, if present.
        ContentCachingRequestWrapperFilter.RequestBodyCachingHttpServletRequestWrapper cached =
                WebUtils.getNativeRequest(request,
                        ContentCachingRequestWrapperFilter.RequestBodyCachingHttpServletRequestWrapper.class);
        if (cached == null) {
            log.debug("[requestBody未缓存, 按空字符串参与签名] uri:{}", request.getRequestURI());
            return "";
        }
        return new String(cached.getContentAsByteArray(), Charsets.UTF_8);
    }

    /**
     * Filter that buffers JSON request bodies, so the interceptor can read them for signature
     * checking and the controller can still read them afterwards.
     *
     * <p>As a {@code @Component}, Spring Boot registers it for all URLs. The {@code @WebFilter}
     * annotation only takes effect under {@code @ServletComponentScan}.
     */
    @Component
    @WebFilter(filterName = "ContentCachingRequestWrapperFilter", urlPatterns = "/*")
    public static class ContentCachingRequestWrapperFilter implements Filter {
        @Autowired
        private ApiAuthConfig apiAuthConfig;

        @Override
        public void init(FilterConfig filterConfig) throws ServletException {
        }

        @Override
        public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
                throws IOException, ServletException {
            if (shouldCacheBody(request)) {
                chain.doFilter(new RequestBodyCachingHttpServletRequestWrapper((HttpServletRequest) request), response);
            } else {
                chain.doFilter(request, response);
            }
        }

        /**
         * Decides whether the request body should be buffered.
         *
         * <p>Buffering happens only when auth is enabled, the body is part of the signature, and
         * the content type is JSON (compared case-insensitively).
         */
        private boolean shouldCacheBody(ServletRequest request) {
            if (!(request instanceof HttpServletRequest)) {
                return false;
            }
            String handleType = ((HttpServletRequest) request).getHeader(HEADER_HANDLE_TYPE_FIELD);
            if (HEADER_VALUE_NOT_VALIDATE_BODY.equals(handleType)) {
                return false;
            }
            if (!Boolean.TRUE.equals(apiAuthConfig.getEnabled())) {
                return false;
            }
            return StringUtils.startsWithIgnoreCase(request.getContentType(), "application/json");
        }

        @Override
        public void destroy() {
        }

        /**
         * Request wrapper that reads the whole body once in its constructor and serves it
         * repeatedly from memory via {@link #getInputStream()} and {@link #getReader()}.
         */
        @Slf4j
        public static class RequestBodyCachingHttpServletRequestWrapper extends HttpServletRequestWrapper {

            private byte[] requestBody;

            public RequestBodyCachingHttpServletRequestWrapper(HttpServletRequest request) {
                super(request);
                try {
                    requestBody = IOUtils.toByteArray(request.getInputStream());
                } catch (IOException e) {
                    log.error(e.getMessage(), e);
                    // Fall back to an empty body rather than a fabricated single 0x00 byte.
                    requestBody = new byte[0];
                }
            }

            @Override
            public ServletInputStream getInputStream() throws IOException {
                return new ContentCachingInputStream(requestBody);
            }

            /**
             * Must be overridden too: the inherited implementation would delegate to the original
             * request, whose stream has already been consumed by the constructor.
             */
            @Override
            public BufferedReader getReader() throws IOException {
                return new BufferedReader(new InputStreamReader(getInputStream(), getCharacterEncoding()));
            }

            /**
             * Returns the request's declared encoding, or UTF-8 when none was declared.
             */
            @Override
            public String getCharacterEncoding() {
                String enc = super.getCharacterEncoding();
                return (enc != null ? enc : Charsets.UTF_8.name());
            }

            /**
             * Returns the cached body (the internal array, not a copy).
             */
            public byte[] getContentAsByteArray() {
                return requestBody;
            }

            /**
             * Replaces the cached body served to later readers.
             */
            public void setContent(byte[] content) {
                requestBody = content;
            }

            /**
             * In-memory ServletInputStream over the cached body.
             */
            private static final class ContentCachingInputStream extends ServletInputStream {

                private final ByteArrayInputStream bais;

                ContentCachingInputStream(byte[] requestBody) {
                    this.bais = new ByteArrayInputStream(requestBody);
                }

                @Override
                public int read() {
                    return bais.read();
                }

                @Override
                public int read(byte[] b, int off, int len) {
                    return bais.read(b, off, len);
                }

                @Override
                public int available() {
                    return bais.available();
                }

                @Override
                public boolean isFinished() {
                    return bais.available() == 0;
                }

                @Override
                public boolean isReady() {
                    // The body is fully buffered in memory, so a read never blocks.
                    return true;
                }

                @Override
                public void setReadListener(ReadListener readListener) {
                    // NOTE(review): Servlet 3.1 non-blocking reads are not supported; the listener is ignored.
                }
            }
        }

    }

    /**
     * API authentication settings bound from the {@code api-auth.*} properties.
     *
     * @author: zet
     * @date: 2018/7/19 8:06
     */
    @Component
    @ConfigurationProperties(prefix = "api-auth")
    @Data
    public static class ApiAuthConfig {
        /**
         * Whether interception is enabled. If left unset ({@code null}), the interceptor and the
         * filter treat it as disabled, so it must be set to {@code true} explicitly.
         */
        private Boolean enabled;

        /**
         * Application id.
         */
        private String appId;

        /**
         * Application secret key.
         */
        private String appKey;

        /**
         * Validity window of a signed timestamp, in seconds (required when enabled).
         */
        private Integer timeOut;

        /**
         * Request URIs that bypass authentication (exact match against the request URI).
         */
        private List<String> whitelists;


        /**
         * Checks whether a request URI is whitelisted.
         *
         * <p>An empty or missing whitelist means nothing is whitelisted. The previous behavior
         * ("everything is whitelisted") silently disabled authentication.
         *
         * @param url the request URI
         * @return {@code true} if {@code url} exactly equals one of the whitelist entries
         */
        public boolean inWhitelists(String url) {
            if (url == null || CollectionUtils.isEmpty(whitelists)) {
                return false;
            }
            return whitelists.contains(url);
        }
    }
}
